{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61862dcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8033ce8",
   "metadata": {},
   "source": [
    "# Dynamic In-Context Learning\n",
    "\n",
    "This recipe allows TensorZero users to set up a dynamic in-context learning variant for any function.\n",
    "Since TensorZero automatically logs all inferences and feedback, it is straightforward to query a set of good examples and retrieve the most relevant ones to put them into context for future inferences.\n",
    "Since TensorZero allows users to add demonstrations for any inference it is also easy to include them in the set of examples as well.\n",
    "This recipe will show use the OpenAI embeddings API only, but we have support for other embeddings providers as well.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b6ecd77",
   "metadata": {},
   "source": [
    "To get started:\n",
    "\n",
    "- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable. For example: `TENSORZERO_CLICKHOUSE_URL=\"http://chuser:chpassword@localhost:8123/tensorzero\"`\n",
    "- Set the `OPENAI_API_KEY` environment variable.\n",
    "- Update the following parameters\n",
    "- Uncomment query filters as appropriate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e136f79c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "CONFIG_PATH = \"../../examples/data-extraction-ner/config/tensorzero.toml\"\n",
    "\n",
    "FUNCTION_NAME = \"extract_entities\"\n",
    "\n",
    "METRIC_NAME: Optional[str] = None\n",
    "\n",
    "MAX_EXAMPLES = 1000\n",
    "\n",
    "# The name of the DICL variant you will want to use. Set this to a meaningful name that does not conflict\n",
    "# with other variants for the function selected above.\n",
    "DICL_VARIANT_NAME = \"gpt_4o_mini_dicl\"\n",
    "\n",
    "# The model to use for the DICL variant. Should match the name of the embedding model defined in your config\n",
    "DICL_EMBEDDING_MODEL = \"openai::text-embedding-3-small\"\n",
    "\n",
    "# The model to use for generation in the DICL variant\n",
    "DICL_GENERATION_MODEL = \"openai::gpt-4o-2024-08-06\"\n",
    "\n",
    "# The number of examples to retrieve for the DICL variant\n",
    "DICL_K = 10\n",
    "\n",
    "# If the metric is a float metric, you can set the threshold to filter the data\n",
    "FLOAT_METRIC_THRESHOLD = 0.5\n",
    "\n",
    "# Whether to use demonstrations for DICL examples\n",
    "USE_DEMONSTRATIONS = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8ea49b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from asyncio import Semaphore\n",
    "\n",
    "import pandas as pd\n",
    "import toml\n",
    "from clickhouse_connect import get_client\n",
    "from openai import AsyncOpenAI\n",
    "from tensorzero import TensorZeroGateway, patch_openai_client\n",
    "from tensorzero.util import uuid7\n",
    "from tqdm.asyncio import tqdm_asyncio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91be3cd0",
   "metadata": {},
   "source": [
    "Initialize the ClickHouse client.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64acef04",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert \"TENSORZERO_CLICKHOUSE_URL\" in os.environ, \"TENSORZERO_CLICKHOUSE_URL environment variable not set\"\n",
    "\n",
    "clickhouse_client = get_client(dsn=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f053ba2",
   "metadata": {},
   "source": [
    "Initialize the TensorZero Client\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9755a248",
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = TensorZeroGateway.build_embedded(clickhouse_url=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"], config_file=CONFIG_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ah6b39u8thu",
   "metadata": {},
   "outputs": [],
   "source": [
    "openai_client = await patch_openai_client(\n",
    "    AsyncOpenAI(),\n",
    "    clickhouse_url=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"],\n",
    "    config_file=CONFIG_PATH,\n",
    "    async_setup=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb667f16",
   "metadata": {},
   "outputs": [],
   "source": [
    "filters = None\n",
    "# To filter on a boolean metric, you can uncomment the following line\n",
    "# filters = BooleanMetricFilter(metric_name=METRIC_NAME, value=True) # or False as needed\n",
    "\n",
    "# To filter on a float metric, you can uncomment the following line\n",
    "# filters = FloatMetricFilter(metric_name=METRIC_NAME, value=0.5, comparison_operator=\">\")\n",
    "# or any other float value as needed\n",
    "# You can even use AND, OR, and NOT operators to combine multiple filters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "778352db",
   "metadata": {},
   "outputs": [],
   "source": [
    "inferences = t0.experimental_list_inferences(\n",
    "    function_name=FUNCTION_NAME,\n",
    "    filters=filters,\n",
    "    output_source=\"demonstration\",\n",
    "    # or \"inference\" if you don't want to use (or don't have) demonstrations\n",
    "    # if you use \"demonstration\" we will restrict to the subset of infereences\n",
    "    # that have demonstrations\n",
    "    limit=MAX_EXAMPLES,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb9f7db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def get_embedding(\n",
    "    text: str, semaphore: Semaphore, model: str = \"text-embedding-3-small\"\n",
    ") -> Optional[list[float]]:\n",
    "    try:\n",
    "        async with semaphore:\n",
    "            response = await openai_client.embeddings.create(\n",
    "                input=text, model=f\"tensorzero::embedding_model_name::{model}\"\n",
    "            )\n",
    "            return response.data[0].embedding\n",
    "    except Exception as e:\n",
    "        print(f\"Error getting embedding: {e}\")\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a610f745",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_CONCURRENT_EMBEDDING_REQUESTS = 50\n",
    "semaphore = Semaphore(MAX_CONCURRENT_EMBEDDING_REQUESTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53bf6ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the 'input' column using the get_embedding function\n",
    "tasks = [get_embedding(str(inference.input), semaphore, DICL_EMBEDDING_MODEL) for inference in inferences]\n",
    "embeddings = await tqdm_asyncio.gather(*tasks, desc=\"Embedding inputs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245d25f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "for inference, embedding in zip(inferences, embeddings):\n",
    "    data.append(\n",
    "        {\n",
    "            \"input\": str(inference.input),\n",
    "            \"output\": str(inference.output),\n",
    "            \"embedding\": embedding,\n",
    "            \"function_name\": FUNCTION_NAME,\n",
    "            \"variant_name\": DICL_VARIANT_NAME,\n",
    "            \"id\": uuid7(),\n",
    "        }\n",
    "    )\n",
    "example_df = pd.DataFrame(data)\n",
    "example_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "142a5284",
   "metadata": {},
   "source": [
    "Prepare the data for the DynamicInContextLearningExample table\n",
    "The table schema is as follows:\n",
    "\n",
    "```\n",
    "CREATE TABLE tensorzero.DynamicInContextLearningExample\n",
    "(\n",
    "    `id` UUID,\n",
    "    `function_name` LowCardinality(String),\n",
    "    `variant_name` LowCardinality(String),\n",
    "    `namespace` String,\n",
    "    `input` String,\n",
    "    `output` String,\n",
    "    `embedding` Array(Float32),\n",
    "    `timestamp` DateTime MATERIALIZED UUIDv7ToDateTime(id)\n",
    ")\n",
    "ENGINE = MergeTree\n",
    "ORDER BY (function_name, variant_name, namespace)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edd73a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert the data into the DiclExample table\n",
    "result = clickhouse_client.insert_df(\n",
    "    \"DynamicInContextLearningExample\",\n",
    "    example_df,\n",
    ")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9c3e195",
   "metadata": {},
   "source": [
    "Finally, add a new variant to your function configuration to try out the Dynamic In-Context Learning variant in practice!\n",
    "\n",
    "If your embedding model name or generation model name in the config is different from the one you used above, you might have to update the config.\n",
    "Be sure and also give the variant some weight and if you are using a JSON function set the json_mode field to \"strict\" if you want.\n",
    "\n",
    "> **Tip:** DICL variants support additional parameters like system instructions or strict JSON mode. See [Configuration Reference](https://www.tensorzero.com/docs/gateway/configuration-reference).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aa1e298",
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_config = {\n",
    "    \"type\": \"experimental_dynamic_in_context_learning\",\n",
    "    \"embedding_model\": DICL_EMBEDDING_MODEL,\n",
    "    \"model\": DICL_GENERATION_MODEL,\n",
    "    \"k\": DICL_K,\n",
    "}\n",
    "full_variant_config = {\"functions\": {FUNCTION_NAME: {\"variants\": {DICL_VARIANT_NAME: variant_config}}}}\n",
    "\n",
    "print(toml.dumps(full_variant_config))"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,py:percent",
   "main_language": "python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
